{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自动语音识别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "hide_input": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/TksaY_FDgnk?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/TksaY_FDgnk?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "自动语音识别（ASR）将语音信号转换为文本，将一系列音频输入映射到文本输出。\n",
    "Siri 和 Alexa 这类虚拟助手使用 ASR 模型来帮助用户日常生活，还有许多其他面向用户的有用应用，如会议实时字幕和会议纪要。\n",
    "\n",
    "本指南将向您展示如何：\n",
    "\n",
    "1. 在 [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) 数据集上对\n",
    "   [Wav2Vec2](https://huggingface.co/facebook/wav2vec2-base) 进行微调，以将音频转录为文本。\n",
    "2. 使用微调后的模型进行推断。\n",
    "\n",
    "<Tip>\n",
    "\n",
    "如果您想查看所有与本任务兼容的架构和检查点，最好查看[任务页](https://huggingface.co/tasks/automatic-speech-recognition)。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "在开始之前，请确保您已安装所有必要的库：\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate jiwer\n",
    "```\n",
    "\n",
    "我们鼓励您登录自己的 Hugging Face 账户，这样您就可以上传并与社区分享您的模型。\n",
    "出现提示时，输入您的令牌登录："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载 MInDS-14 数据集"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先从🤗 Datasets 库中加载 [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14)\n",
    "数据集的一个较小子集。这将让您有机会先进行实验，确保一切正常，然后再花更多时间在完整数据集上进行训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, Audio\n",
    "\n",
    "minds = load_dataset(\"PolyAI/minds14\", name=\"en-US\", split=\"train[:100]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用 `~Dataset.train_test_split` 方法将数据集的 `train` 拆分为训练集和测试集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "minds = minds.train_test_split(test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后看看数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],\n",
       "        num_rows: 16\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],\n",
       "        num_rows: 4\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "虽然数据集包含 `lang_id` 和 `english_transcription` 等许多有用的信息，但在本指南中，\n",
    "您将专注于 `audio` 和 `transcription`。使用 `remove_columns` 方法删除其他列："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "minds = minds.remove_columns([\"english_transcription\", \"intent_class\", \"lang_id\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "再看看示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'array': array([-0.00024414,  0.        ,  0.        , ...,  0.00024414,\n",
       "          0.00024414,  0.00024414], dtype=float32),\n",
       "  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       "  'sampling_rate': 8000},\n",
       " 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       " 'transcription': \"hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing\"}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minds[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "有 2 个字段：\n",
    "\n",
    "- `audio`：由语音信号形成的一维 `array`，用于加载和重新采样音频文件。\n",
    "- `transcription`：目标文本。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预处理"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下一步是加载一个 Wav2Vec2 处理器来处理音频信号："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"facebook/wav2vec2-base\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MInDS-14 数据集的采样率为 8000kHz（您可以在其[数据集卡片](https://huggingface.co/datasets/PolyAI/minds14)中找到此信息），\n",
    "这意味着您需要将数据集重新采样为 16000kHz 以使用预训练的 Wav2Vec2 模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'audio': {'array': array([-2.38064706e-04, -1.58618059e-04, -5.43987835e-06, ...,\n",
       "          2.78103951e-04,  2.38446111e-04,  1.18740834e-04], dtype=float32),\n",
       "  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       "  'sampling_rate': 16000},\n",
       " 'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602ba9e2963e11ccd901cd4f.wav',\n",
       " 'transcription': \"hi I'm trying to use the banking app on my phone and currently my checking and savings account balance is not refreshing\"}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minds = minds.cast_column(\"audio\", Audio(sampling_rate=16_000))\n",
    "minds[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如您在上面的 `transcription` 中所看到的，文本包含大小写字符的混合。\n",
    "Wav2Vec2 分词器仅训练了大写字符，因此您需要确保文本与分词器的词汇表匹配："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def uppercase(example):\n",
    "    return {\"transcription\": example[\"transcription\"].upper()}\n",
    "\n",
    "\n",
    "minds = minds.map(uppercase)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在创建一个预处理函数，该函数应该：\n",
    "\n",
    "1. 调用 `audio` 列以加载和重新采样音频文件。\n",
    "2. 从音频文件中提取 `input_values` 并使用处理器对 `transcription` 列执行 tokenizer 操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_dataset(batch):\n",
    "    audio = batch[\"audio\"]\n",
    "    batch = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"], text=batch[\"transcription\"])\n",
    "    batch[\"input_length\"] = len(batch[\"input_values\"][0])\n",
    "    return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "要在整个数据集上应用预处理函数，可以使用🤗 Datasets 的 `map` 函数。\n",
    "您可以通过增加 `num_proc` 参数来加速 `map` 的处理进程数量。\n",
    "使用 `remove_columns` 方法删除不需要的列："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_minds = minds.map(prepare_dataset, remove_columns=minds.column_names[\"train\"], num_proc=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🤗 Transformers 没有用于 ASR 的数据整理器，因此您需要调整 [DataCollatorWithPadding](https://huggingface.co/docs/transformers/main/zh/main_classes/data_collator#transformers.DataCollatorWithPadding) 来创建一个示例批次。\n",
    "它还会动态地将您的文本和标签填充到其批次中最长元素的长度（而不是整个数据集），以使它们具有统一的长度。\n",
    "虽然可以通过在 `tokenizer` 函数中设置 `padding=True` 来填充文本，但动态填充更有效。\n",
    "\n",
    "与其他数据整理器不同，这个特定的数据整理器需要对 `input_values` 和 `labels` 应用不同的填充方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Any, Dict, List, Optional, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataCollatorCTCWithPadding:\n",
    "    processor: AutoProcessor\n",
    "    padding: Union[bool, str] = \"longest\"\n",
    "\n",
    "    def __call__(self, features: list[dict[str, Union[list[int], torch.Tensor]]]) -> dict[str, torch.Tensor]:\n",
    "        # split inputs and labels since they have to be of different lengths and need\n",
    "        # different padding methods\n",
    "        input_features = [{\"input_values\": feature[\"input_values\"][0]} for feature in features]\n",
    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
    "\n",
    "        batch = self.processor.pad(input_features, padding=self.padding, return_tensors=\"pt\")\n",
    "\n",
    "        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors=\"pt\")\n",
    "\n",
    "        # replace padding with -100 to ignore loss correctly\n",
    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
    "\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在实例化您的 `DataCollatorForCTCWithPadding`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorCTCWithPadding(processor=processor, padding=\"longest\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在训练过程中包含一个指标通常有助于评估模型的性能。\n",
    "您可以通过🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) 库快速加载一个评估方法。\n",
    "对于这个任务，加载 [word error rate](https://huggingface.co/spaces/evaluate-metric/wer)（WER）指标\n",
    "（请参阅🤗 Evaluate [快速上手](https://huggingface.co/docs/evaluate/a_quick_tour)以了解如何加载和计算指标）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "wer = evaluate.load(\"wer\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后创建一个函数，将您的预测和标签传递给 `compute` 来计算 WER："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(pred):\n",
    "    pred_logits = pred.predictions\n",
    "    pred_ids = np.argmax(pred_logits, axis=-1)\n",
    "\n",
    "    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
    "\n",
    "    pred_str = processor.batch_decode(pred_ids)\n",
    "    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
    "\n",
    "    wer = wer.compute(predictions=pred_str, references=label_str)\n",
    "\n",
    "    return {\"wer\": wer}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "您的 `compute_metrics` 函数现在已经准备就绪，当您设置好训练时将返回给此函数。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "如果您不熟悉使用`Trainer`微调模型，请查看这里的基本教程[here](https://huggingface.co/docs/transformers/main/zh/tasks/../training#train-with-pytorch-trainer)！\n",
    "\n",
    "</Tip>\n",
    "\n",
    "现在您已经准备好开始训练您的模型了！使用 `AutoModelForCTC` 加载 Wav2Vec2。\n",
    "使用 `ctc_loss_reduction` 参数指定要应用的减少方式。通常最好使用平均值而不是默认的求和："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForCTC, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForCTC.from_pretrained(\n",
    "    \"facebook/wav2vec2-base\",\n",
    "    ctc_loss_reduction=\"mean\",\n",
    "    pad_token_id=processor.tokenizer.pad_token_id,"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此时，只剩下 3 个步骤：\n",
    "\n",
    "1. 在 `TrainingArguments` 中定义您的训练参数。唯一必需的参数是 `output_dir`，用于指定保存模型的位置。\n",
    "   您可以通过设置 `push_to_hub=True` 将此模型推送到 Hub（您需要登录到 Hugging Face 才能上传您的模型）。\n",
    "   在每个 epoch 结束时，`Trainer` 将评估 WER 并保存训练检查点。\n",
    "2. 将训练参数与模型、数据集、分词器、数据整理器和 `compute_metrics` 函数一起传递给 `Trainer`。\n",
    "3. 调用 `train()` 来微调您的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"my_awesome_asr_mind_model\",\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=2,\n",
    "    learning_rate=1e-5,\n",
    "    warmup_steps=500,\n",
    "    max_steps=2000,\n",
    "    gradient_checkpointing=True,\n",
    "    fp16=True,\n",
    "    group_by_length=True,\n",
    "    eval_strategy=\"steps\",\n",
    "    per_device_eval_batch_size=8,\n",
    "    save_steps=1000,\n",
    "    eval_steps=1000,\n",
    "    logging_steps=25,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"wer\",\n",
    "    greater_is_better=False,\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=encoded_minds[\"train\"],\n",
    "    eval_dataset=encoded_minds[\"test\"],\n",
    "    processing_class=processor,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练完成后，使用 `push_to_hub()` 方法将您的模型分享到 Hub，方便大家使用您的模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "要深入了解如何微调模型进行自动语音识别，\n",
    "请查看这篇博客[文章](https://huggingface.co/blog/fine-tune-wav2vec2-english)以了解英语 ASR，\n",
    "还可以参阅[这篇文章](https://huggingface.co/blog/fine-tune-xlsr-wav2vec2)以了解多语言 ASR。\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 推断"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "很好，现在您已经微调了一个模型，您可以用它进行推断了！\n",
    "\n",
    "加载您想要运行推断的音频文件。请记住，如果需要，将音频文件的采样率重新采样为与模型匹配的采样率！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, Audio\n",
    "\n",
    "dataset = load_dataset(\"PolyAI/minds14\", \"en-US\", split=\"train\")\n",
    "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
    "sampling_rate = dataset.features[\"audio\"].sampling_rate\n",
    "audio_file = dataset[0][\"audio\"][\"path\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "尝试使用微调后的模型进行推断的最简单方法是使用 [pipeline()](https://huggingface.co/docs/transformers/main/zh/main_classes/pipelines#transformers.pipeline)。\n",
    "使用您的模型实例化一个用于自动语音识别的 `pipeline`，并将您的音频文件传递给它："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': 'I WOUD LIKE O SET UP JOINT ACOUNT WTH Y PARTNER'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "transcriber = pipeline(\"automatic-speech-recognition\", model=\"stevhliu/my_awesome_asr_minds_model\")\n",
    "transcriber(audio_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "转录结果还不错，但可以更好！尝试用更多示例微调您的模型，以获得更好的结果！\n",
    "\n",
    "</Tip>\n",
    "\n",
    "如果您愿意，您也可以手动复制 `pipeline` 的结果：\n",
    "\n",
    "\n",
    "加载一个处理器来预处理音频文件和转录，并将 `input` 返回为 PyTorch 张量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"stevhliu/my_awesome_asr_mind_model\")\n",
    "inputs = processor(dataset[0][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将您的输入传递给模型并返回 logits："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCTC\n",
    "\n",
    "model = AutoModelForCTC.from_pretrained(\"stevhliu/my_awesome_asr_mind_model\")\n",
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "获取具有最高概率的预测 `input_ids`，并使用处理器将预测的 `input_ids` 解码回文本："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['I WOUL LIKE O SET UP JOINT ACOUNT WTH Y PARTNER']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "predicted_ids = torch.argmax(logits, dim=-1)\n",
    "transcription = processor.batch_decode(predicted_ids)\n",
    "transcription"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
